import copy
import re
import time
import datetime

import torch; torch.manual_seed(42)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import os
import copy
import numpy as np
from joblib import dump, load
import sys

import torch
import pickle
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import pandas as pd
from datasets import Dataset
from transformers.trainer_callback import EarlyStoppingCallback
from transformers import TrainerCallback
from sklearn.decomposition import PCA
from joblib import dump, load

import numpy as np
from itertools import combinations
import itertools


# ---- PCA + VAE pipeline configuration ---------------------------------------
if_pca = True        # selects the PCA-reduced row width for the VAE input below
n_components = 400   # row width after PCA (used for input_shape_vae)

n_stacked_pca = 4    # entries grouped by stack() in pca_preprocess/pca_postprocess
n_stacked_vae = 128  # entries grouped by stack() in vae_preprocess/vae_postprocess
delta_layer_shape = (1, 2, 4096)
input_shape_pca = (1, 2 * n_stacked_pca, 4096)
# Two rows per grouped entry; the row width is PCA-reduced when if_pca is set
# and the full 4096 otherwise.
input_shape_vae = (1, 2 * n_stacked_vae, n_components if if_pca else 4096)

full_delta_shape = (128, 2, 4096)
input_size = int(np.prod(input_shape_vae))  # element count of one VAE input
output_shape = input_shape_vae

latent_dims = 8
hidden_dims = (512, 512)
n_inputs = 128 // n_stacked_vae  # number of n_stacked_vae-sized groups in 128 entries
epochs = 300
lr = 1e-4


# ---- LLM (QLoRA) fine-tuning configuration ---------------------------------
# NOTE(review): none of these values are referenced by the helper functions
# below; presumably they are consumed by training code elsewhere — confirm.

# The model that you want to train from the Hugging Face hub
model_name = "Llama-2-7b-hf"

################################################################################
# QLoRA parameters
################################################################################

# LoRA attention dimension
lora_r = 2

# Alpha parameter for LoRA scaling
lora_alpha = 8

# Dropout probability for LoRA layers
lora_dropout = 0.1

################################################################################
# bitsandbytes parameters
################################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

################################################################################
# TrainingArguments parameters
################################################################################

# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"

# Number of training epochs
num_train_epochs = 1

# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = False
bf16 = False

# Batch size per GPU for training
per_device_train_batch_size = 1

# Batch size per GPU for evaluation
per_device_eval_batch_size = 1

# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 1

# Enable gradient checkpointing
gradient_checkpointing = True

# Maximum gradient norm (gradient clipping)
max_grad_norm = 0.3

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001

# Optimizer to use
optim = "paged_adamw_32bit"

# Learning rate schedule
lr_scheduler_type = "constant"

# Number of training steps (overrides num_train_epochs)
max_steps = -1

# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0 #0.03

# Group sequences into batches with same length
# Saves memory and speeds up training considerably
group_by_length = True

# Save a checkpoint every X update steps
save_steps = 0

# Log every X update steps
logging_steps = 25

################################################################################
# SFT parameters
################################################################################

# Maximum sequence length to use
max_seq_length = None

# Pack multiple short examples in the same input sequence to increase efficiency
packing = False

# Load the entire model on GPU 0
device_map = {"": 0}



################## HELPER FUNCTIONS #################

def apply_delta(model, layer_names, delta):
    """Copy ``delta[i]`` in place into the state-dict tensor ``layer_names[i]``.

    Args:
        model: ``nn.Module`` whose ``state_dict()`` contains every name in
            ``layer_names``.
        layer_names: sequence of state-dict keys to overwrite.
        delta: indexable of tensors; ``delta[i]`` must be broadcastable to the
            shape of the target tensor (``Tensor.copy_`` semantics).

    Returns:
        The same ``model`` object, modified in place.
    """
    # state_dict() rebuilds an OrderedDict over the whole model on every call;
    # fetch it once rather than once per layer. Its tensors share storage with
    # the model, so copy_ writes through to the live parameters.
    state = model.state_dict()
    with torch.no_grad():
        for i, name in enumerate(layer_names):
            state[name].copy_(delta[i])
    return model


def cross_entropy_evaluation(model, tokenizer, test_list, bos="", eos=""):
    """Return the mean next-token cross-entropy of ``model`` over ``test_list``.

    Each string is wrapped as ``bos + text + eos``, tokenized, and scored one
    sequence at a time (batch size 1). A sequence's loss is the mean
    cross-entropy of predicting token t+1 from the logits at position t. The
    result is the unweighted mean of the per-sequence losses.

    Args:
        model: ``nn.Module`` called as ``model(input_ids, attention_mask=...)``
            whose output exposes ``.logits`` of shape (1, seq_len, vocab).
        tokenizer: callable mapping a list of strings to a mapping with
            ``"input_ids"`` and ``"attention_mask"`` lists.
        test_list: non-empty sequence of strings to evaluate.
        bos: string prepended to every text.
        eos: string appended to every text.

    Returns:
        float: the average per-sequence loss.

    Raises:
        ValueError: if ``test_list`` is empty.

    Note: puts ``model`` into eval mode and leaves it there.
    """
    if len(test_list) == 0:
        raise ValueError("test_list must contain at least one string")

    # Inputs have to live on the same device as the model's weights
    # (e.g. a model loaded with device_map={"": 0}).
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else None

    total_eval_loss = 0.0
    with torch.no_grad():
        model.eval()
        encoded = tokenizer([bos + txt + eos for txt in test_list])
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"]):
            input_ids = torch.tensor(ids).unsqueeze(0)
            attention_mask = torch.tensor(mask).unsqueeze(0)
            if device is not None:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)

            # labels are intentionally not passed: the model's built-in loss
            # would be computed and then discarded.
            logits = model(input_ids, attention_mask=attention_mask).logits

            # Shift so that the logits at position t predict token t+1.
            shift_logits = logits[:, :-1, :]
            shift_labels = input_ids[:, 1:]
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            total_eval_loss += loss.item()
    return total_eval_loss / len(test_list)


def inverse_pca_by_layer(pca_delta, pca_by_layer):
    """Map each group of PCA coordinates back to full row width.

    This is the counterpart of ``apply_pca_by_layer``. Group ``i`` is passed
    through ``pca_by_layer[i].inverse_transform`` and reshaped to
    ``input_shape_pca[1]`` rows.

    Args:
        pca_delta: indexable of tensors; ``pca_delta[i]`` holds the PCA
            coordinates of group ``i`` (moved to CPU and detached here).
        pca_by_layer: sequence of fitted objects exposing
            ``inverse_transform`` (presumably sklearn ``PCA``), one per group.

    Returns:
        float32 tensor of shape
        ``(len(pca_by_layer), input_shape_pca[1], width)``, where ``width`` is
        the reconstructed row width.
    """
    restored = []
    for i, pca in enumerate(pca_by_layer):
        coords = pca_delta[i].cpu().detach()
        # -1 instead of a hard-coded 4096 mirrors apply_pca_by_layer and lets
        # the row width follow whatever the PCA was fitted on.
        rows = pca.inverse_transform(coords).reshape(input_shape_pca[1], -1)
        restored.append(torch.tensor(rows).float())
    return torch.stack(restored, dim=0)

def apply_pca_by_layer(delta, pca_by_layer):
    """Project each group of ``delta`` with its own fitted PCA.

    ``delta[i]`` is reshaped to ``input_shape_pca[1]`` rows and passed to
    ``pca_by_layer[i].transform``. The projected rows are converted to
    float32 tensors and stacked along a new leading dimension.
    """
    n_rows = input_shape_pca[1]
    projected = [
        torch.tensor(
            pca.transform(delta[idx].reshape(n_rows, -1)).reshape(n_rows, -1)
        ).float()
        for idx, pca in enumerate(pca_by_layer)
    ]
    return torch.stack(projected, axis=0)

def stack(delta, n_stacked):
    """Group consecutive entries of ``delta`` into blocks of ``n_stacked``.

    Each block ``vstack``s ``n_stacked`` consecutive entries, and the blocks
    are stacked along a new leading dimension.

    Args:
        delta: sequence (list or tensor) of equally shaped tensors. With 2-D
            entries of shape (r, c), the result has shape
            ``(len(delta) // n_stacked, n_stacked * r, c)``.
        n_stacked: number of consecutive entries per block; must be a
            positive divisor of ``len(delta)``.

    Returns:
        The stacked tensor.

    Raises:
        ValueError: if ``n_stacked`` is not a positive divisor of ``len(delta)``.
    """
    n_entries = len(delta)
    if n_stacked <= 0 or n_entries % n_stacked != 0:
        raise ValueError(
            f"n_stacked={n_stacked} must be a positive divisor of len(delta)={n_entries}"
        )
    # Iterate over the real length (previously a hard-coded 128) so inputs
    # with any number of entries are grouped correctly.
    blocks = [
        torch.vstack(list(delta[i:i + n_stacked]))
        for i in range(0, n_entries, n_stacked)
    ]
    return torch.stack(blocks, axis=0)

def unstack(delta_stacked, n_stacked):
    """Split every block of ``delta_stacked`` into consecutive 2-row slices.

    Each block ``delta_stacked[b]`` is squeezed, and its rows
    ``[0:2], [2:4], ...`` (``n_stacked`` slices) are appended in order. This is
    the inverse of ``stack`` for entries with two rows.

    Returns:
        A flat list of ``delta_stacked.shape[0] * n_stacked`` tensors.
    """
    pieces = []
    for b in range(delta_stacked.shape[0]):
        rows = delta_stacked[b].squeeze()
        pieces.extend(rows[start:start + 2, :] for start in range(0, 2 * n_stacked, 2))
    return pieces

def get_encoder_diff(vae, base_delta, new_delta, preprocess_fn):
    """Compare the encoder's Gaussian parameters for two deltas.

    Both deltas are passed through ``preprocess_fn`` first (base, then new).
    ``vae.encoder`` is then run on each in the same order, and its ``mu`` /
    ``sigma`` attributes are read after every call.

    Returns:
        ``(base_mu, base_sigma, new_mu - base_mu, new_sigma - base_sigma)``

    NOTE(review): there is no ``torch.no_grad()`` here, so autograd graphs are
    built if the encoder has trainable parameters; confirm whether callers
    need gradients. With ``MultiheadVariationalEncoder``, each call also draws
    a noise sample and overwrites ``vae.encoder.kl``.
    """
    inputs = (preprocess_fn(base_delta), preprocess_fn(new_delta))
    stats = []
    for encoder_input in inputs:
        vae.encoder(encoder_input)
        stats.append((vae.encoder.mu, vae.encoder.sigma))
    (base_mu, base_sigma), (new_mu, new_sigma) = stats
    return base_mu, base_sigma, new_mu - base_mu, new_sigma - base_sigma

def back_to_og_shapes(delta, og_shapes):
    """Reshape each leading-axis slice of ``delta`` to its original shape.

    For every ``idx`` in ``range(delta.shape[0])``, ``delta[idx]`` is reshaped
    to ``og_shapes[idx]``. The results are returned as a list.
    """
    n_slices = delta.shape[0]
    restored = []
    for idx in range(n_slices):
        restored.append(delta[idx].reshape(og_shapes[idx]))
    return restored

def pca_preprocess(raw_delta, pca_by_layer):
    """Group ``raw_delta`` for PCA, project it, and split it into 2-row pieces.

    Groups of ``n_stacked_pca`` entries are built with ``stack``, moved to the
    CPU and detached, then projected with ``apply_pca_by_layer``. The result
    is split back with ``unstack`` into a list of pieces.
    """
    grouped_cpu = stack(raw_delta, n_stacked_pca).cpu().detach()
    return unstack(apply_pca_by_layer(grouped_cpu, pca_by_layer), n_stacked_pca)

def vae_preprocess(delta_pca):
    """Group PCA-reduced pieces into blocks of ``n_stacked_vae`` entries (VAE input)."""
    return stack(delta_pca, n_stacked_vae)

def vae_postprocess(delta_vae_out):
    """Split VAE output blocks back into a flat list of 2-row pieces."""
    return unstack(delta_vae_out, n_stacked_vae)

def pca_postprocess(delta_vae_unstacked, pca_by_layer):
    """Map VAE-output pieces back to full row width via inverse PCA.

    Pieces are regrouped into ``n_stacked_pca``-entry groups and passed
    through ``inverse_pca_by_layer``. They are then split back into 2-row
    pieces and stacked into one tensor along a new leading dimension.
    """
    regrouped = stack(delta_vae_unstacked, n_stacked_pca)
    pieces = unstack(inverse_pca_by_layer(regrouped, pca_by_layer), n_stacked_pca)
    return torch.stack(pieces, axis=0)

class Encoder(nn.Module):
    """Single encoder head: flatten the whole input, then Linear + ReLU.

    ``latent_dims`` is accepted for signature compatibility but unused. The
    head's output width is ``hidden_dims[0]``.
    """

    def __init__(self, input_size, latent_dims, hidden_dims):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_dims[0])

    def forward(self, x):
        # Flattens every dimension (no batch axis is kept), so x must contain
        # exactly input_size elements.
        x = torch.flatten(x)
        return F.relu(self.linear1(x))


class MultiheadVariationalEncoder(nn.Module):
    """Variational encoder with one ``Encoder`` head per input slice.

    ``forward(x)`` feeds ``x[i]`` to head ``i`` and concatenates the head
    outputs. Two linear layers then produce the mean and the log standard
    deviation of a diagonal Gaussian over ``latent_dims`` dimensions. The
    reparameterised sample ``z`` is returned. The latest ``mu``, ``sigma``
    and KL term are kept as ``self.mu``, ``self.sigma`` and ``self.kl``.
    """

    def __init__(self, input_size, latent_dims, hidden_dims, n_inputs):
        super(MultiheadVariationalEncoder, self).__init__()
        print("New Encoder")
        self.encoders = nn.ModuleList([Encoder(input_size, latent_dims, hidden_dims) for _ in range(n_inputs)])
        # NOTE: linear1 is not used in forward(). It is kept so that existing
        # checkpoints (state_dict keys) and the parameter-initialisation RNG
        # order stay unchanged.
        self.linear1 = nn.Linear(input_size, hidden_dims[0])
        self.linear2 = nn.Linear(hidden_dims[0] * n_inputs, latent_dims)  # -> mu
        self.linear3 = nn.Linear(hidden_dims[0] * n_inputs, latent_dims)  # -> log(sigma)

        # Standard normal for the reparameterisation noise. Its loc/scale are
        # CPU tensors; samples are moved to mu's device/dtype in forward().
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0
        self.mu = 0
        self.sigma = 0

    def forward(self, x):
        head_outputs = [self.encoders[i](x[i]) for i in range(x.shape[0])]
        out = torch.flatten(torch.stack(head_outputs, axis=0))

        mu = self.linear2(out)
        log_sigma = self.linear3(out)
        sigma = torch.exp(log_sigma)
        self.mu = mu
        self.sigma = sigma

        # Reparameterisation trick. The noise follows mu onto its
        # device/dtype so the module also works on CUDA.
        eps = self.N.sample(mu.shape).to(mu)
        z = mu + sigma * eps

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims:
        #   0.5 * (sigma^2 + mu^2 - 1) - log(sigma)
        # log_sigma is used directly to avoid log(exp(.)) under/overflow.
        # (The variant sigma^2 + mu^2 - log(sigma) - 1/2 is not this KL: it is
        # minimised at sigma = 1/sqrt(2) instead of the prior's sigma = 1.)
        self.kl = (0.5 * (sigma ** 2 + mu ** 2 - 1) - log_sigma).sum()
        return z

class Decoder(nn.Module):
    """Single decoder head: ``hidden_dims[0]``-wide vector -> reshaped output.

    Three linear layers map the input to ``output_size`` values, which are
    reshaped to ``output_shape``. ``latent_dims`` and ``n_inputs`` are
    accepted for signature compatibility but unused.
    """

    def __init__(self, output_size, output_shape, latent_dims, hidden_dims, n_inputs):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(hidden_dims[0], hidden_dims[0])
        self.linear2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.linear3 = nn.Linear(hidden_dims[1], output_size)
        # Use the shape given to this instance, not the module-level global.
        self.output_shape = output_shape

    def forward(self, z):
        z = F.relu(self.linear1(z))
        # NOTE(review): there is no activation between linear2 and linear3, so
        # the two layers compose to a single affine map; confirm intended.
        z = self.linear2(z)
        z = self.linear3(z)
        return z.reshape(self.output_shape)


class MultiheadDecoder(nn.Module):
    """Decode one latent vector into ``n_inputs`` reconstructions.

    A shared Linear + ReLU expands ``z`` to ``hidden_dims[0] * n_inputs``
    features. These are split into ``n_inputs`` chunks of ``hidden_dims[0]``,
    chunk ``i`` is decoded by head ``i``, and the results are stacked along a
    new leading dimension.
    """

    def __init__(self, output_size, output_shape, latent_dims, hidden_dims, n_inputs):
        super(MultiheadDecoder, self).__init__()
        print("New Decoder")
        self.linear1 = nn.Linear(latent_dims, hidden_dims[0] * n_inputs)
        self.decoders = nn.ModuleList([Decoder(output_size, output_shape, latent_dims, hidden_dims, n_inputs) for _ in range(n_inputs)])
        # Kept on the instance so forward() uses this module's configuration
        # instead of the module-level hidden_dims / n_inputs globals.
        self.chunk_size = hidden_dims[0]
        self.n_inputs = n_inputs

    def forward(self, z):
        z = F.relu(self.linear1(z))
        # Split along dim 0, which assumes an unbatched (1-D) latent vector.
        z_split = torch.split(z, self.chunk_size)
        if len(z_split) != self.n_inputs:
            raise ValueError(
                f"expected {self.n_inputs} chunks of size {self.chunk_size}, got {len(z_split)}"
            )
        out = [decoder(chunk) for decoder, chunk in zip(self.decoders, z_split)]
        return torch.stack(out, axis=0)

class MultiheadVariationalAutoencoder(nn.Module):
    """Multi-head VAE: a ``MultiheadVariationalEncoder`` feeding a ``MultiheadDecoder``.

    ``input_size`` is used both as each encoder head's input width and as
    each decoder head's output size. The sampled latent from the encoder is
    passed straight to the decoder.
    """

    def __init__(self, input_size, output_shape, latent_dims, hidden_dims, n_inputs):
        super().__init__()
        self.encoder = MultiheadVariationalEncoder(input_size, latent_dims, hidden_dims, n_inputs)
        self.decoder = MultiheadDecoder(input_size, output_shape, latent_dims, hidden_dims, n_inputs)

    def forward(self, x):
        """Encode ``x`` to a sampled latent and return its decoding."""
        return self.decoder(self.encoder(x))
    

def euclidean_distance(tensor1, tensor2):
    """
    Compute the Euclidean distance between two tensors.

    The result is the 2-norm of the flattened difference, as computed by
    ``np.linalg.norm`` with default arguments. The inputs may be NumPy arrays
    or torch tensors. A torch difference is detached and moved to the CPU
    first, so CUDA tensors and tensors that require grad are accepted.

    Returns:
        The norm as returned by ``np.linalg.norm`` (a NumPy scalar).
    """
    diff = tensor1 - tensor2
    if isinstance(diff, torch.Tensor):
        # np.linalg.norm cannot read CUDA tensors or tensors requiring grad.
        diff = diff.detach().cpu().numpy()
    return np.linalg.norm(diff)

